SageMaker Sparkを使ったXGBoostでの手書き数字分類をやってみた

SageMaker Sparkを使ったXGBoostでの手書き数字分類をやってみた

Clock Icon2018.09.21

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

SageMaker Sparkを使用して、Amazon SageMaker上で手書き数字の分類モデルを作成するノートブックをやってみたので、その内容を紹介します。

概要

SageMaker Sparkを使って、Sparkをローカル上に展開し、MNISTのデータを読み込み、SageMaker上でXGBoostの分類モデルの学習とホストを行います。その後、作成したモデルを使ってテストデータの分類を行います。

SageMaker Spark
Apache SparkのSageMaker用のオープンソースのライブラリです。Sparkに読み込んだデータを使ったSageMaker上での学習や、SparkのMLlibとSageMakerを連携させること等が出来ます。
MNISTのデータセット
手書き数字画像とラベルデータを含んだデータセットです。データ分析のチュートリアルでよく使われるデータセットの一つです。
XGBoost
勾配ブースティングツリーという理論のオープンソースの実装で、分類や回帰によく使われる機械学習アルゴリズムの一つです。

基本的にはawslabsのノートブックに沿って進めますが、一部変更している箇所があります。

やってみた

ノートブックの作成

SageMakerのノートブックインスタンスを立ち上げて、 SageMaker Examples

Sagemaker Spark

pyspark_mnist_xgboost.ipynb

use
でサンプルからノートブックをコピーして、開きます。 ノートブックインスタンスの作成についてはこちらをご参照ください。

セットアップ

IAMロールの取得とSparkセッションの構築を行います。 今回はローカル上にSparkアプリケーションを実行し、セッションに繋ぎます。 リモートのSparkクラスタと接続する場合はAmazon SageMaker PySparkのGitHubでの解説をご参照ください。

import os

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

import sagemaker
from sagemaker import get_execution_role
import sagemaker_pyspark

# 実行しているIAMロールを取得する
role = get_execution_role()

# SageMaker Sparkが依存するjarを取得する
jars = sagemaker_pyspark.classpath_jars()

classpath = ":".join(sagemaker_pyspark.classpath_jars())


# ローカルで実行するSparkアプリケーションとのセッションを取得する
spark = SparkSession.builder.config("spark.driver.extraClassPath", classpath)\
    .master("local[*]").getOrCreate()

データの読み込み

MNISTのデータセットをSpark Dataframeに読み込みます。 データセットはLibSVM形式のものがS3上で公開されています。 LibSVM形式についてはこちらのエントリをご覧ください。

学習と推論のデータセットと使用するには2種類のカラムが必要です。一つはダブル型のカラム(デフォルトではlabelという名前)、二つ目はダブル型のベクトル(デフォルトではfeaturesという名前)です。

import boto3

region = boto3.Session().region_name

trainingData = spark.read.format('libsvm')\
    .option('numFeatures', '784')\
    .option('vectorType', 'dense')\
    .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region))

testData = spark.read.format('libsvm')\
    .option('numFeatures', '784')\
    .option('vectorType', 'dense')\
    .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region))

trainingData.show()

モデルの学習とホスティング

Estimatorを作成し、学習とモデルをホストするエンドポイントの作成を行います。 Estimatorには学習とエンドポイントのパラメータ、学習時のハイパーパラメータを設定します。 XGBoostのハイパーパラメータについてはドキュメントをご参照ください。

import random
from sagemaker_pyspark import IAMRole, S3DataPath
from sagemaker_pyspark.algorithms import XGBoostSageMakerEstimator
from sagemaker_pyspark.S3Resources import S3DataPath

# モデルアーティファクトの出力先を指定(ノートブックにはなく、追加した内容です。)
output_s3_data = S3DataPath('bucket_name', 'sagemaker/spark-xgb/output')

# estimatorの設定
xgboost_estimator = XGBoostSageMakerEstimator(
    sagemakerRole=IAMRole(role), # 学習時とエンドポイントの作成時に使用するIAMロール
    trainingInstanceType='ml.m4.xlarge', # 学習に使用するインスタンスタイプ
    trainingInstanceCount=1, # 学習に使用するインスタンス数
    endpointInstanceType='ml.m4.xlarge', # モデルをホストするエンドポイントのインスタンスタイプ
    endpointInitialInstanceCount=1, # モデルをホストするエンドポイントのインスタンス数
    trainingOutputS3DataPath = output_s3_data) # モデルアーティファクトの出力先

# ハイパーパラメータの設定
xgboost_estimator.setEta(0.2)
xgboost_estimator.setGamma(4)
xgboost_estimator.setMinChildWeight(6)
xgboost_estimator.setSilent(0)
xgboost_estimator.setObjective("multi:softmax")
xgboost_estimator.setNumClasses(10)
xgboost_estimator.setNumRound(10)

# 学習処理開始
model = xgboost_estimator.fit(trainingData)

推論

テストデータをまとめて分類(推論)します。この処理を行うために内部的には、featureカラムをLibSVM形式に変換して、モデルをホストするエンドポイントに投げます。その後、エンドポイントでモデルが分類した結果をCSV形式で受け取って、DataFrameに格納します。この処理をtransformがまとめてやってくれます。便利です。
※ ノートブックではtrainDataを読み込んでいますが、testDataの方が適切だと思われるので、testDataを読み込むように変更しています。

transformedData = model.transform(testData)

transformedData.show()

入力データに加えてpredictionカラムが追加されています。

分類結果ごとに入力画像を表示してみて、正しく分類できたかを確認します。

from pyspark.sql.types import DoubleType
import matplotlib.pyplot as plt
import numpy as np

# 数字を表示するための補助関数
def show_digit(img, caption='', xlabel='', subplot=None):
    if subplot==None:
        _,(subplot)=plt.subplots(1,1)
    imgr=img.reshape((28,28))
    subplot.axes.get_xaxis().set_ticks([])
    subplot.axes.get_yaxis().set_ticks([])
    plt.title(caption)
    plt.xlabel(xlabel)
    subplot.imshow(imgr, cmap='gray')

# 入力画像データを取得
images = np.array(transformedData.select("features").cache().take(250))
# 分類結果を取得
clusters = transformedData.select("prediction").cache().take(250)

# 分類結果ごとに入力画像を表示する
for cluster in range(10):
    print('\n\n\nCluster {}:'.format(int(cluster)))
    digits=[ img for l, img in zip(clusters, images) if int(l.prediction) == cluster ]
    height=((len(digits) - 1) // 5) + 1
    width=5
    plt.rcParams["figure.figsize"] = (width,height)
    _, subplots = plt.subplots(height, width)
    subplots=np.ndarray.flatten(subplots)
    for subplot, image in zip(subplots, digits):
        show_digit(image, subplot=subplot)
    for subplot in subplots[len(digits):]:
        subplot.axis('off')

    plt.show()

実際には0~9までの数字に対して結果が表示されますが、ここでは0と9の結果だけ紹介します。 0に分類された入力画像は全て正しく0を表していそうです。

9に分類された入力画像は基本的に正しそうですが、7や8が紛れ込んでいます。

概ね高い精度で分類出来ていそうです。

エンドポイントの削除

余計な費用が掛からないようにエンドポイントを削除します。

from sagemaker_pyspark import SageMakerResourceCleanup

resource_cleanup = SageMakerResourceCleanup(model.sagemakerClient)
resource_cleanup.deleteResources(model.getCreatedResources())

おわりに

SageMaker Sparkライブラリを使用して、Amazon SageMaker上で手書き数字の分類を行えました。SageMaker Sparkを使うことで、Amazon SageMakerのモデルをSpark上で扱うことができます。大きなデータを扱ったり、Apache SparkのMLlibとAmazon SageMakerを連携させる際には非常に便利そうです。

これからAmazon SageMakerとApache Sparkを試してみたいと思っている方の参考になれば幸いです。 最後までお読み頂き有難うございましたー!

参考

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.